import gym
import numpy as np
import pdb
from utils import powerset, symlog, symsqrt, sigmoid, tanh
import math, random
import copy

# Divisor applied to the time index before a sine-based schedule is evaluated
# (see SAQueue._sigmoid_sine), so that consecutive integer time steps move
# only a small distance along the curve.
TIME_SMOOTHER = 10000

# Assumes a one-to-one mapping between nodes and queues.
# SA stands for "server allocation".
class SAQueue:
    """A single job queue with (possibly time-varying) Bernoulli parameters.

    ``queue_info`` is a dict with one entry per metric: ``'arrival'``,
    ``'service'`` and ``'connection'``.  Every entry carries an
    ``'is_stationary'`` flag.  Stationary entries provide a constant
    ``'prob'``; non-stationary entries provide a ``'type'``:

    * ``'piecewise'``: ``'probs'`` (one value per phase), ``'length'`` (steps
      spent in each phase) and ``'trans_length'`` (steps of linear
      interpolation between one phase and the next).
    * ``'sine'``: ``'period'`` and ``'offset'``, plus optional
      ``'multiplier'`` (default 1) and ``'periodic_offset'`` (default 0).

    The current backlog is kept in ``self.num_jobs``.
    """

    def __init__(self, name, queue_info, random_starts=True, gridworld=False):
        """Store the configuration and draw the initial backlog.

        Args:
            name: identifier of the queue.
            queue_info: per-metric schedule description (see class docstring).
            random_starts: when true, ``reset`` draws a random initial backlog
                instead of starting empty.
            gridworld: when true, random initial backlogs may be negative.
        """
        self.name = name
        self.queue_info = queue_info
        self.random_starts = random_starts
        # Exclusive upper bound of the randomly drawn initial backlog.
        self.queue_init_high = 100
        self.gridworld = gridworld
        self.reset()

    def reset(self):
        """Re-initialise ``num_jobs`` (0, or a random draw if ``random_starts``)."""
        if not self.random_starts:
            self.num_jobs = 0
            return
        # Gridworld mode allows starting on either side of zero.
        low = -self.queue_init_high if self.gridworld else 0
        self.num_jobs = np.random.randint(low=low, high=self.queue_init_high)

    def get_arrival_prob(self, t):
        """Return the arrival probability at time step ``t``."""
        return self._get_prob('arrival', t)

    def get_service_prob(self, t):
        """Return the service probability at time step ``t``."""
        return self._get_prob('service', t)

    def get_connection_prob(self, t):
        """Return the connection probability at time step ``t``."""
        return self._get_prob('connection', t)

    def _get_prob(self, metric, t):
        """Return the constant or time-dependent value configured for ``metric``."""
        spec = self.queue_info[metric]
        if spec['is_stationary']:
            return spec['prob']
        return self._get_periodic_value(metric, t)

    def _get_periodic_value(self, metric, t):
        """Evaluate the non-stationary schedule of ``metric`` at time ``t``.

        Raises:
            ValueError: if the configured ``'type'`` is neither ``'piecewise'``
                nor ``'sine'``.
        """
        spec = self.queue_info[metric]
        typ = spec['type']

        if typ == 'piecewise':
            probs = spec['probs']
            pieces = len(probs)
            phase_length = spec['length']
            trans_length = spec['trans_length']
            # One segment = a constant phase followed by its transition.
            segment = phase_length + trans_length

            t = t % (pieces * segment)
            piece, local_t = divmod(t, segment)
            piece = int(piece)
            if local_t < phase_length:
                return probs[piece]
            # Linear ramp towards the next phase; the last phase wraps around
            # to the first so the schedule is periodic for any piece count.
            target = probs[(piece + 1) % pieces]
            slope = (target - probs[piece]) / trans_length
            return probs[piece] + slope * (local_t - phase_length)

        if typ == 'sine':
            # Every parameter comes from the metric's own entry, so service and
            # connection schedules no longer borrow the arrival settings.
            return self._sigmoid_sine(
                t,
                spec['period'],
                spec.get('multiplier', 1),
                spec.get('periodic_offset', 0),
                spec['offset'],
            )

        raise ValueError(
            "unknown schedule type {!r} for metric {!r}".format(typ, metric))

    def _sigmoid_sine(self, t, period, multiplier=1, periodic_offset=0, offset=0):
        """Return ``sigmoid(multiplier * sin(2*pi*t'/period + periodic_offset) + offset)``.

        ``t'`` is ``t / TIME_SMOOTHER``; the sigmoid keeps the result in (0, 1).
        """
        t = t / TIME_SMOOTHER  # simulate a near-continuous move along the curve
        val = multiplier * np.sin((2 * np.pi * t / period) + periodic_offset)
        return 1. / (1. + np.exp(-(val + offset)))

class SANetwork(gym.Env):
    """Server-allocation environment over a list of ``SAQueue`` objects.

    At every step the agent picks one queue index.  If that queue is currently
    connected and its service draw succeeds, its job count moves one unit
    towards the goal; afterwards every queue receives a Bernoulli arrival and
    new connectivity bits are sampled.

    The native state is ``[job count per queue..., connectivity bit per
    queue...]``; ``self.state`` is that vector after ``transform_state``.
    """

    def __init__(self, queues, network_name='server_alloc', reward_func='avg-q-len',
                 r_mix_ratio=1., state_trans='id', reward_transformation='id',
                 use_mask=False, reset_eps=0, gridworld=False,
                 opt_warmup_time=int(1e6), opt_beta=4e-6, state_bound=np.inf):
        """Configure the environment (``reset`` must be called before ``step``).

        Args:
            queues: list of queue objects, one per action.
            network_name: accepted for interface compatibility; not used here.
            reward_func: name of the reward, see ``reward_function``.
            r_mix_ratio: stored on the instance; not read by this class.
            state_trans: name of the state transformation, see
                ``transform_state``.
            reward_transformation: stored on the instance; not read by this
                class.
            use_mask: whether ``mask_extractor`` computes real masks.
            reset_eps: stored on the instance; not read by this class.
            gridworld: when true, job counts may be negative.
            opt_warmup_time: time offset of the reward warm-up weight.
            opt_beta: slope of the reward warm-up weight.
            state_bound: magnitude at which observed job counts are clipped.
        """
        self.qs = queues
        self.reward_func = reward_func
        self.action_space = gym.spaces.Discrete(len(self.qs))
        # One job count plus one connectivity bit per queue.
        dim = 2 * len(self.qs)
        self.lower_state_bound = -state_bound if gridworld else 0
        self.upper_state_bound = state_bound
        self.observation_space = gym.spaces.Box(
            low=self.lower_state_bound, high=self.upper_state_bound,
            shape=(dim,), dtype=float)
        self.horizon = -1
        self.r_mix_ratio = r_mix_ratio
        self.state_trans = state_trans
        self.reward_transformation = reward_transformation
        self.use_mask = use_mask
        self.reset_eps = reset_eps
        self.t = 0
        self.gridworld = gridworld
        self.opt_warmup_time = opt_warmup_time
        self.opt_beta = opt_beta

    def set_reset_eps(self, eps):
        """Set the stored ``reset_eps`` value."""
        self.reset_eps = eps

    def set_horizon(self, horizon):
        """Set the episode length; ``-1`` means episodes never report done."""
        self.horizon = horizon

    def set_state_transformation(self, trans):
        """Select the transformation applied by ``transform_state``."""
        self.state_trans = trans

    def _get_period(self, q, param_name):
        """Return the period of ``param_name`` for queue ``q``.

        Stationary parameters have period 1.  Piecewise schedules repeat after
        ``pieces * (length + trans_length)``.  For any other type the
        configured ``'period'`` is used; a value below 1 is scaled by a power
        of ten taken from its decimal representation and rounded to an int.
        """
        spec = q.queue_info[param_name]
        if spec['is_stationary']:
            return 1
        if spec['type'] == 'piecewise':
            return len(spec['probs']) * (spec['length'] + spec['trans_length'])
        period = spec['period']
        if period < 1:
            num_dec = str(period)[::-1].find('.')
            # round() rather than int(): e.g. 0.29 * 100 is 28.999... in
            # floating point and would otherwise truncate to 28.
            period = int(round(period * (10 ** num_dec)))
        return period

    # common multiple of periods considering for different parameters and queues
    def get_period_cm(self, q_sub):
        """Return a common multiple of all periods of the queues in ``q_sub``.

        The periods of each parameter are multiplied across the subset and the
        least common multiple of the three products is returned.

        NOTE(review): sine schedules are evaluated at ``t / TIME_SMOOTHER``, so
        this value may be shorter than one full sine cycle in time steps --
        confirm whether a TIME_SMOOTHER factor is needed.
        """
        lam_prod = 1
        p_prod = 1
        c_prod = 1
        for q_idx in q_sub:
            q = self.qs[q_idx]
            lam_prod *= self._get_period(q, 'arrival')
            p_prod *= self._get_period(q, 'service')
            c_prod *= self._get_period(q, 'connection')
        return math.lcm(lam_prod, p_prod, c_prod)

    def is_stable(self):
        """Check the stability criterion for every non-empty subset of queues.

        For each subset and each time step within one common period the
        condition ``sum(lam / p) < 1 - prod(1 - c)`` must hold, where ``lam``,
        ``p`` and ``c`` are the arrival, service and connection probabilities.

        Returns:
            ``False`` as soon as a violation is found; otherwise ``True``
            after printing the smallest observed margin.
        """
        min_gap = float('inf')
        for sub in powerset(np.arange(len(self.qs))):
            # An empty subset imposes no constraint (it would read 0 >= 0).
            if len(sub) == 0:
                continue
            T = self.get_period_cm(sub)
            for t in np.arange(T):
                lam_p = 0
                c_prod = 1
                for q_idx in sub:
                    q = self.qs[q_idx]
                    lam = q.get_arrival_prob(t)
                    p = q.get_service_prob(t)
                    c = q.get_connection_prob(t)
                    lam_p += (lam / p)
                    c_prod *= (1. - c)
                capacity = 1. - c_prod
                if lam_p >= capacity:
                    return False
                min_gap = min(min_gap, capacity - lam_p)
        print ('distance from decision boundary {}'.format(min_gap))
        return True

    def transform_state(self, state):
        """Apply the configured transformation to the job-count part of ``state``.

        The connectivity part is passed through unchanged.  With ``'id'`` the
        whole state is returned as a new array.

        Raises:
            ValueError: if ``self.state_trans`` is not a known transformation.
        """
        mode = self.state_trans
        if mode == 'id':
            return np.array(state)

        n_queues = len(self.qs)
        lens = state[:n_queues]
        connects = state[n_queues:]
        if mode == 'symloge':
            adj_lens = symlog(lens, base='e')
        elif mode == 'symlog10':
            adj_lens = symlog(lens, base='10')
        elif mode == 'symsqrt':
            adj_lens = symsqrt(lens)
        elif mode == 'sigmoid':
            adj_lens = sigmoid(lens)
        elif mode == 'tanh':
            adj_lens = tanh(lens)
        else:
            raise ValueError('unknown state transformation {!r}'.format(mode))
        return np.concatenate((adj_lens, connects))

    def _warmup_weight(self):
        """Return ``tanh(opt_beta * max(t - opt_warmup_time, 0.01))``."""
        return np.tanh(self.opt_beta * np.maximum(self.t - self.opt_warmup_time, 0.01))

    def reset(self, reset_time=True):
        """Reset all queues and sample a fresh initial state.

        Args:
            reset_time: when true the time index is set back to 0.

        Returns:
            ``(state, info)`` where ``state`` is the transformed initial state
            and ``info`` is an empty dict.
        """
        if reset_time:
            self.t = 0
        lens = []
        for q in self.qs:
            q.reset()
            lens.append(q.num_jobs)
        lens = np.array(lens)

        # Connectivity is part of the state.  It is sampled at the current
        # time so that a reset with reset_time=False matches what step() uses.
        connect_probs = [q.get_connection_prob(self.t) for q in self.qs]
        self.connects = np.random.binomial(n=1, p=connect_probs)

        self.pre_arrival_lens = np.concatenate((lens, self.connects))
        self.native_state = np.concatenate((lens, self.connects))
        self.native_prev_state = np.concatenate((lens, self.connects))

        self.state = self.transform_state(self.native_state)
        self.prev_state = self.transform_state(self.native_prev_state)
        # Use the configured hyper-parameters (their defaults equal the
        # constants that used to be hard-coded here).
        self.w_temp = self._warmup_weight()
        self.goal = np.array([0 for _ in range(len(self.qs))])
        print ('init state: {}, goal state: {}'.format(self.native_state, self.goal))
        return self.state, {}

    def _overall_avg_backlog(self, state):
        """Return the mean absolute distance of the job counts from the goal."""
        return np.mean(np.abs(state[:len(self.qs)] - self.goal))

    def _avg_backlog(self, state, next_state, weighted=False):
        """Return the (optionally self-weighted) mean backlog of ``next_state``.

        ``state`` is not used.  With ``weighted`` each backlog is weighted by
        its share of the total backlog (a zero total is treated as 1).
        """
        lens = np.abs(next_state[:len(self.qs)] - self.goal)
        if not weighted:
            return np.mean(lens)
        su = np.sum(lens)
        if su == 0:
            su = 1.
        weights = lens / su
        return np.sum(weights * lens)

    def _backlog_change(self, state, next_state):
        """Return the change of the mean backlog from ``state`` to ``next_state``."""
        prev_lengths = np.mean(np.abs(state[:len(self.qs)] - self.goal))
        # Manhattan distance to the goal, averaged over queues.
        curr_lengths = np.mean(np.abs(next_state[:len(self.qs)] - self.goal))
        return curr_lengths - prev_lengths

    def _sq_backlog_change(self, state, next_state):
        """Return the change of the mean squared backlog between two states."""
        n_queues = len(self.qs)
        prev_lens = np.mean(np.square(state[:n_queues] - self.goal))
        curr_lens = np.mean(np.square(next_state[:n_queues] - self.goal))
        return curr_lens - prev_lens

    def reward_function(self, state=None, action=None, next_state=None, pre_arrial_state=None):
        """Compute the reward selected by ``self.reward_func``.

        Args:
            state: native state before the transition.
            action: unused.
            next_state: native state after the transition.
            pre_arrial_state: unused.

        The two ``mix-*avg-q-len-and-change`` rewards also refresh
        ``self.w_temp`` from the current time.

        Raises:
            ValueError: if ``self.reward_func`` is not a known reward name.
        """
        mode = self.reward_func

        if mode == 'avg-q-len':
            return -1 * self._avg_backlog(state, next_state)
        if mode == 'reci-avg-q-len':
            return 1. / (self._avg_backlog(state, next_state) + 1)
        if mode == 'mix-avg-q-len-and-change':
            avg_q_len = self._avg_backlog(state, next_state)
            neg_change_avg_q_len = -1 * self._backlog_change(state, next_state)
            self.w_temp = self._warmup_weight()
            return self.w_temp * (1. / (avg_q_len + 1)) + neg_change_avg_q_len
        if mode == 'avg-q-len-change':
            return -1 * self._backlog_change(state, next_state)
        if mode == 'avg-sq-q-len-change':
            return -1 * self._sq_backlog_change(state, next_state)
        if mode == 'mix-sq-avg-q-len-and-change':
            avg_q_len = self._avg_backlog(state, next_state)
            return -1 * self._sq_backlog_change(state, next_state) - avg_q_len
        if mode == 'w-avg-q-len':
            return -1 * self._avg_backlog(state, next_state, weighted=True)
        if mode == 'mix-w-avg-q-len-and-change':
            w_avg_q_len = self._avg_backlog(state, next_state, weighted=True)
            neg_change_avg_q_len = -1 * self._backlog_change(state, next_state)
            self.w_temp = self._warmup_weight()
            return self.w_temp * -w_avg_q_len + neg_change_avg_q_len
        raise ValueError('unknown reward function {!r}'.format(mode))

    def step(self, a):
        """Serve queue ``a``, then apply arrivals and resample connectivity.

        Given the current state and action, the next state depends on three
        Bernoulli draws per queue: service, arrival and connection.

        Returns:
            ``(state, reward, done, False, info)``.  ``done`` is only ever
            true when a horizon has been set and reached.
        """
        assert self.action_space.contains(a)

        # record previous state
        self.prev_state = self.state
        self.native_prev_state = self.native_state

        # service draws for all queues; only the chosen one can be used
        service_probs = [q.get_service_prob(self.t) for q in self.qs]
        service_success = np.random.binomial(n=1, p=service_probs)

        # a job is served only if the server can connect to the queue
        if self.connects[a] and service_success[a]:
            served_q = self.qs[a]
            current_pos = served_q.num_jobs
            # move one unit towards the goal; at the goal nothing happens
            if current_pos > self.goal[a]:
                served_q.num_jobs = current_pos - 1
            elif current_pos < self.goal[a]:
                served_q.num_jobs = current_pos + 1

            # trim if negatives not allowed
            if not self.gridworld:
                served_q.num_jobs = max(served_q.num_jobs, 0)

        # Only the observation is clipped; q.num_jobs itself is left as is.
        pre_arrival_lens = np.array([q.num_jobs for q in self.qs])
        pre_arrival_lens = np.clip(
            pre_arrival_lens, self.lower_state_bound, self.upper_state_bound)

        # arrival of new jobs: each arrival moves a queue away from its goal
        arrival_probs = [q.get_arrival_prob(self.t) for q in self.qs]
        arrive_success = np.random.binomial(n=1, p=arrival_probs)
        new_lengths = []
        for idx, q in enumerate(self.qs):
            current_pos = q.num_jobs
            # compare against this queue's own goal, not the served queue's
            if current_pos >= self.goal[idx]:
                q.num_jobs = current_pos + arrive_success[idx]
            else:
                q.num_jobs = current_pos - arrive_success[idx]

            if not self.gridworld:
                q.num_jobs = max(q.num_jobs, 0)
            new_lengths.append(q.num_jobs)
        new_lengths = np.clip(
            np.array(new_lengths), self.lower_state_bound, self.upper_state_bound)

        # connectivity of the new state
        connect_probs = [q.get_connection_prob(self.t) for q in self.qs]
        self.connects = np.random.binomial(n=1, p=connect_probs)

        self.pre_arrival_lens = np.concatenate((pre_arrival_lens, self.connects))
        self.native_state = np.concatenate((new_lengths, self.connects))
        self.state = self.transform_state(self.native_state)

        backlog = self._overall_avg_backlog(state=self.native_state)

        # change-based rewards need the updated state
        reward = self.reward_function(
            self.native_prev_state, a, self.native_state, self.pre_arrival_lens)

        self.t += 1
        done = False
        if self.horizon != -1:
            done = self.t >= self.horizon

        info = {
            'backlog': backlog,
            'time': self.t,
            'services': service_success,
            'arrivals': arrive_success,
            'native_state': self.native_prev_state,
            'next_native_state': self.native_state,
            'connects': self.connects,
            'action': a
        }
        return self.state, reward, done, False, info

    def _circle_sample(self, num_points, x, y, radius):
        """Sample ``num_points`` points in the disc of ``radius`` around ``(x, y)``.

        Coordinates are clipped at 0.  Every point is drawn around the given
        centre; the centre is not moved between draws.
        """
        points = []
        for _ in range(num_points):
            # random angle
            alpha = 2 * math.pi * random.random()
            # random radius
            r = radius * math.sqrt(random.random())
            px = r * math.cos(alpha) + x
            py = r * math.sin(alpha) + y
            points.append([max(px, 0), max(py, 0)])
        return points

    def augment_samples(self, samples):
        """Build synthetic transitions that reuse the random outcomes of ``samples``.

        Each sample is indexed as ``s[0]`` (native state), ``s[1]`` (service
        outcomes), ``s[2]`` (arrival outcomes) and ``s[3]`` (connectivity of
        the next state).  For every sample, new current job counts are drawn
        near the first two queue lengths and every action is replayed on them.

        NOTE(review): the sampling uses only ``lens[0]`` and ``lens[1]``, so
        this presumably expects exactly two queues -- confirm.

        Returns:
            list of ``(state, action, next_state, reward)`` with transformed
            states.
        """
        n_queues = len(self.qs)
        num_actions = self.action_space.n
        num_new_curr_states = 5
        new_samples = []
        for s in samples:
            # generate new states away from the sampled one
            lens = s[0][:n_queues]
            new_curr_lens = np.array(
                self._circle_sample(num_new_curr_states, lens[0], lens[1], 50)).astype(int)

            # extract exogenous and endogenous variables info
            connects = s[0][n_queues:]
            services = s[1]
            arrivals = s[2]
            new_connects = s[3]

            # The state normally holds one connectivity bit per queue; samples
            # in the older 2-wide one-hot layout are still understood.
            if len(connects) == 2 * n_queues:
                connected_flags = connects[1::2]
            else:
                connected_flags = connects

            for new_curr_len in new_curr_lens:
                for a in range(num_actions):
                    new_next_len = new_curr_len.copy()
                    if connected_flags[a] and services[a]:
                        new_next_len[a] = max(new_curr_len[a] - 1, 0)

                    # count all new arrivals
                    for idx, ar in enumerate(arrivals):
                        new_next_len[idx] += ar
                    new_curr_state_full = np.concatenate((np.array(new_curr_len), connects))
                    new_next_state_full = np.concatenate((np.array(new_next_len), new_connects))
                    rew = self.reward_function(new_curr_state_full, a, new_next_state_full)
                    new_samples.append((
                        self.transform_state(new_curr_state_full),
                        a,
                        self.transform_state(new_next_state_full),
                        rew,
                    ))
        return new_samples

    def set_use_mask(self, flag):
        """Enable or disable action masking in ``mask_extractor``."""
        self.use_mask = flag

    def set_state_bound(self, bound):
        """Update the clipping bounds used by ``step``.

        ``observation_space`` is not modified.
        """
        self.lower_state_bound = -bound if self.gridworld else 0
        self.upper_state_bound = bound

    def mask_extractor(self, obs):
        """Return a boolean action mask for every row of ``obs``.

        Only the last ``observation_space.shape[0]`` columns of ``obs`` are
        read.  An entry is ``True`` when the queue's length entry or its
        connectivity entry equals 0.  When ``use_mask`` is false, all entries
        are ``False``.
        """
        obs_dim = self.observation_space.shape[0]
        n_queues = len(self.qs)
        masks = np.zeros((obs.shape[0], self.action_space.n)).astype(bool)
        if self.use_mask:
            obs = obs[:, -obs_dim:]
            lens = obs[:, :n_queues]
            cons = obs[:, n_queues:]
            # vectorised over examples and queues
            masks[:, :n_queues] = np.logical_or(lens == 0, cons[:, :n_queues] == 0)
        return masks
